#! /bin/python3
################################################################################
# Copyright 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################

import argparse
import os
import shlex
import subprocess
import sys
from tempfile import NamedTemporaryFile

from matmul import sampler as matmul_sampler
from matmul import primitive as matmul


def log(output):
    """Print an informational message tagged with the ``synthdnn:`` prefix."""
    message = "synthdnn: " + output
    print(message)


def error(output):
    """Report a fatal error and terminate the program with exit status 1.

    The message is prefixed with ``synthdnn: error:`` and written to stderr,
    so diagnostics are not mixed into output that may be redirected (for
    example, a batch file emitted on stdout).

    Args:
        output: human-readable description of the failure.

    Raises:
        SystemExit: always, with code 1.
    """
    print("synthdnn: error: " + output, file=sys.stderr)
    # Use sys.exit rather than the site-provided ``exit`` builtin, which is
    # intended for the interactive shell and is missing under ``python -S``.
    sys.exit(1)


def write_batch_file(batch_file, samples):
    """Write a benchdnn batch file describing every sample.

    The output starts with a two-line header recording the synthdnn command
    line that produced it (followed by a blank line). After that comes one
    ``--reset <problem>`` line per sample. The stream is flushed before
    returning.

    Args:
        batch_file: writable text stream (an open file or ``sys.stdout``).
        samples: iterable of objects providing ``benchdnn_str()``.
    """
    invocation = " ".join(sys.argv[1:])
    header_lines = (
        "#### Auto-generated by synthdnn\n",
        f"#### python3 synthdnn.py {invocation}\n\n",
    )
    for header_line in header_lines:
        batch_file.write(header_line)

    for sample in samples:
        problem = sample.benchdnn_str()
        batch_file.write(f"--reset {problem}\n")

    batch_file.flush()


def setup_collect_subparser(subparsers):
    """Register the ``collect`` sub-command on ``subparsers``.

    The sub-command takes the path to a benchdnn executable and a required
    batch file, plus options forwarded to benchdnn. Its entry point
    (``collect_main``) is stored as the default of a hidden
    ``--subprogram_main`` option so the dispatcher can call it.
    """
    parser = subparsers.add_parser("collect", help="call with -h for information")
    # Hidden option: its default value carries this sub-command's handler.
    parser.add_argument(
        "--subprogram_main", default=collect_main, help=argparse.SUPPRESS
    )
    parser.add_argument(
        "-b", "--batch-file", required=True, help="batch file used for the operation"
    )

    # Options describing how benchdnn is invoked, in registration order.
    benchdnn_options = (
        (("benchdnn",), {"help": "path to benchdnn executable"}),
        (
            ("--engine",),
            {"default": "cpu", "help": "engine used for benchdnn execution"},
        ),
        (
            ("--impl",),
            {"default": None, "help": "implementation to use in benchdnn execution"},
        ),
        (
            ("--skip-impl",),
            {"default": None, "help": "implementation to skip in benchdnn execution"},
        ),
        (
            ("--collect",),
            {
                "default": "corr",
                "help": "benchdnn collection type, can be one of [corr, perf]",
            },
        ),
        (("-n", "--name"), {"default": "", "help": "sample name"}),
    )
    for flags, settings in benchdnn_options:
        parser.add_argument(*flags, **settings)


def get_optional_args(args):
    """Build the optional benchdnn implementation-filter flags.

    Returns ``--impl=<value>`` and/or ``--skip-impl=<value>`` (in that order,
    each only when the corresponding option is truthy). The flags are
    space-separated with one trailing space, so the result can be placed
    directly before further arguments. Returns an empty string when neither
    option is set.
    """
    candidates = (("--impl", args.impl), ("--skip-impl", args.skip_impl))
    flags = [f"{flag}={value}" for flag, value in candidates if value]
    if not flags:
        return ""
    return " ".join(flags) + " "


def collect_main(args):
    """Run benchdnn over a batch file (entry point of ``collect``).

    Builds a benchdnn command line for matmul from the parsed arguments and
    executes it. ``args.collect`` selects the flag set: ``corr`` or ``perf``.
    Any failure is reported through ``error``, which terminates the process.
    Failures include a missing executable, an unknown collection type, a
    comma in the perf sample name, a failed launch, or a non-zero exit
    status.

    The command is executed directly as an argument list (no shell). Paths
    and sample names containing spaces or shell metacharacters are therefore
    passed through verbatim, and the real process exit code is reported.
    ``os.system`` would instead return an encoded wait status.
    """
    # args.benchdnn may be a list depending on command line setup
    benchdnn = args.benchdnn
    if isinstance(benchdnn, list):
        benchdnn = benchdnn[0]

    if not os.path.exists(benchdnn):
        error(f"cannot execute {benchdnn}, no such file exists")

    if args.collect == "corr":
        mode_args = ["--mode-modifier=P"]
    elif args.collect == "perf":
        # The perf template is comma separated, so a comma in the sample
        # name would corrupt the emitted records.
        if "," in args.name:
            error(f"sample name {args.name} contains invalid character: ,")
        mode_args = [
            "--mode=F",
            "--cold-cache=all",
            f"--perf-template=sample,{args.name},%prb%,%0Gflops%,%0Gbw%",
            "--memory-kind=usm_device",
            "--attr-scratchpad=user",
        ]
    else:
        error(f"unknown collection method {args.collect}")

    cmd = [
        benchdnn,
        f"--engine={args.engine}",
        "--matmul",
        *mode_args,
        *get_optional_args(args).split(),
        f"--batch={args.batch_file}",
    ]
    # Shell-quoted rendering, used only for log and error messages.
    printable = " ".join(shlex.quote(part) for part in cmd)

    log(f"executing: {printable}")
    try:
        ret = subprocess.run(cmd).returncode
    except OSError as exc:
        # e.g. the file exists but is not executable.
        error(f"execution of {printable} failed: {exc}")
    log("execution complete")
    if ret != 0:
        error(f"execution of {printable} failed with return code {ret}")


def setup_matmul_subparser(subparsers):
    """Register the ``matmul`` sub-command on ``subparsers``.

    The sub-command samples matmul problems and writes them as a benchdnn
    batch file. Its options (layouts, iteration mode, region, sample count,
    types) configure the sampler. The entry point (``matmul_main``) is stored
    as the default of a hidden ``--subprogram_main`` option so the dispatcher
    can call it.
    """
    parser = subparsers.add_parser("matmul", help="call with -h for information")
    # Hidden option: its default value carries this sub-command's handler.
    parser.add_argument(
        "--subprogram_main", default=matmul_main, help=argparse.SUPPRESS
    )
    parser.add_argument(
        "-b", "--batch-file", default=None, help="batch file to write results to"
    )

    # Sampler options, in registration order (this is also the -h order).
    sampler_options = (
        (
            ("-l", "--layouts"),
            {
                "default": "all",
                "help": 'stag:wtag:dtag, comma separated list of layouts or "all" for every supported layout',
            },
        ),
        (
            ("-m", "--iter-mode"),
            {
                "default": "zip",
                "help": "iteration mode, must be one of zip or product",
            },
        ),
        (
            ("-r", "--region"),
            {
                "default": "(1,1,1,1):(8,8192,8192,8192):(1,1,1,1)",
                "help": "([b_min,]m_min,n_min,k_min):([b_max,]m_max,n_max,k_max):([b_align,]m_align,n_align,k_align)",
            },
        ),
        (
            ("-s", "--samples"),
            {"default": 1000, "help": "number of samples to collect"},
        ),
        (
            ("-t", "--types"),
            {
                "default": "*",
                # "%%" is required: argparse %-formats help strings.
                "help": 'dt:dt:dt(optional fpmath-mode), comma separated list of type configurations. "%%N" will match the Nth given data type. Giving a single data type instead of 3 will match any supported configurations using that type.',
            },
        ),
    )
    for flags, settings in sampler_options:
        parser.add_argument(*flags, **settings)


def matmul_main(args):
    """Sample matmul problems and emit them as a benchdnn batch file.

    This is the entry point of the ``matmul`` sub-command. The region, type
    and layout specifications are parsed before any output file is touched,
    so invalid arguments do not leave behind a truncated batch file. When
    ``args.batch_file`` is None the batch is written to stdout. Otherwise
    the file is created (overwriting any existing one) and closed once
    writing finishes.
    """
    region = matmul_sampler.Region(args.region)
    types = matmul.Types(args.types)
    # NOTE(review): ndims - 1 presumably maps the region's dimension count to
    # the tensor rank expected by Layouts — confirm against matmul.Layouts.
    layouts = matmul.Layouts(args.layouts, region.ndims - 1)
    samples = matmul_sampler.Sampler(
        int(args.samples), args.iter_mode, types, layouts, region
    )

    if args.batch_file is None:
        write_batch_file(sys.stdout, samples)
        return

    log(f"generating batch file: {args.batch_file}")
    # `with` guarantees the file is closed even if sampling raises mid-write.
    with open(args.batch_file, "w") as batch_file:
        write_batch_file(batch_file, samples)
    log("generation complete")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(
        help="primitive targeted for data collection"
    )
    setup_collect_subparser(subparsers)
    setup_matmul_subparser(subparsers)
    args = parser.parse_args()
    # Sub-commands are optional by default, so with none given the namespace
    # lacks subprogram_main. Report a usage error instead of crashing with an
    # AttributeError.
    subprogram_main = getattr(args, "subprogram_main", None)
    if subprogram_main is None:
        parser.error("a sub-command is required (one of: collect, matmul)")
    subprogram_main(args)
